import numpy as np
import dataset
import online_logistic_regression
import copy

def _check_lengths(X, y):
    """Raise ValueError if the example and label sequences differ in length."""
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length (got {len(X)} and {len(y)})"
        )


def compute_accuracy(model, X, y):
    """Return the fraction of examples in ``X`` that ``model`` labels correctly.

    Only ``model.predict`` is called; the model is not modified.

    Args:
        model: Object with a ``predict(example)`` method.
        X: Sequence of examples supporting ``len`` and iteration.
        y: Sequence of labels, aligned with ``X``.

    Returns:
        ``correct / len(X)`` as a float, or NaN when ``X`` is empty (accuracy
        over zero examples is undefined; this avoids a ZeroDivisionError).

    Raises:
        ValueError: If ``X`` and ``y`` have different lengths.
    """
    _check_lengths(X, y)
    if len(X) == 0:
        return float("nan")
    correct = sum(1 for example, label in zip(X, y) if model.predict(example) == label)
    return correct / len(X)


def train_online(model, X, y):
    """Make one online pass over ``(X, y)``, updating ``model`` in place.

    Each example is scored with ``model.predict`` *before* ``model.update`` is
    called on it, so the returned value is a progressive (online) accuracy,
    not the accuracy of the final model on the training set.

    Args:
        model: Object with ``predict(example)`` and
            ``update(example=..., label=...)`` methods.
        X: Sequence of training examples.
        y: Sequence of labels, aligned with ``X``.

    Returns:
        Progressive accuracy as a float, or NaN when ``X`` is empty.

    Raises:
        ValueError: If ``X`` and ``y`` have different lengths.
    """
    _check_lengths(X, y)
    if len(X) == 0:
        return float("nan")
    correct = 0
    for example, label in zip(X, y):
        # Predict first so the score reflects a model that has not seen this example.
        if model.predict(example) == label:
            correct += 1
        model.update(example=example, label=label)
    return correct / len(X)


if __name__ == "__main__":

    # Split dataset into train/test: first half for training, second half for
    # testing (in file order, no shuffling).
    X, y = dataset.read_libsvm_dataset("rcv1_test.binary", binary=True)
    num_train_examples = len(X) // 2
    X_train, X_test = X[:num_train_examples], X[num_train_examples:]
    y_train, y_test = y[:num_train_examples], y[num_train_examples:]
    print("Finished processing dataset")

    # Parameters
    lr_init = 0.1
    l2_reg = 0.0001
    no_bias = True

    # Train classifier using online logistic regression on X_train, y_train.
    model = online_logistic_regression.OnlineLogisticRegression(
        lr_init=lr_init, l2_reg=l2_reg, no_bias=no_bias
    )
    print("Training accuracy: ", train_online(model, X_train, y_train))

    # Test accuracy with the full weight vector.
    full_weight_accuracy = compute_accuracy(model, X_test, y_test)
    print("Test accuracy: ", full_weight_accuracy)

    # Test accuracy when keeping only the top-K weights. The -1 entry and any
    # other K for which sparsify_top_k returns -1 are skipped (presumably -1
    # signals an unsupported K -- confirm against OnlineLogisticRegression).
    list_of_ks = [-1, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700,
                  800, 900, 1000, 10000, 20000, 40000]

    # Results are written as "<label> & <accuracy>" rows for a LaTeX table.
    # NOTE(review): rows end with a plain newline, not a LaTeX "\\" row
    # terminator; add it here if the file is pasted directly into a tabular.
    with open("results.txt", "w", encoding="utf-8") as results_file:
        results_file.write(f"Full weights & {full_weight_accuracy}\n")

        for K in list_of_ks:
            # Sparsify a copy so every K starts from the same trained weights.
            k_sparse_model = copy.deepcopy(model)
            if k_sparse_model.sparsify_top_k(k=K) == -1:
                continue

            k_accuracy = compute_accuracy(k_sparse_model, X_test, y_test)
            print("Test accuracy using top ", K, " weights: ", k_accuracy)
            results_file.write(f"{K} & {k_accuracy}\n")